{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd \n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow as tf \n",
    "from tensorflow.keras import models,layers,losses,metrics,callbacks "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 数据准备"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'svg'\n",
    "\n",
    "df = pd.read_csv(\"./data/covid-19.csv\",sep = \"\\t\")\n",
    "df.plot(x = \"date\",y = [\"confirmed_num\",\"cured_num\",\"dead_num\"],figsize=(10,6))\n",
    "plt.xticks(rotation=60)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dfdata = df.set_index(\"date\")\n",
    "dfdiff = dfdata.diff(periods=1).dropna()\n",
    "dfdiff = dfdiff.reset_index(\"date\")\n",
    "\n",
    "dfdiff.plot(x = \"date\",y = [\"confirmed_num\",\"cured_num\",\"dead_num\"],figsize=(10,6))\n",
    "plt.xticks(rotation=60)\n",
    "dfdiff = dfdiff.drop(\"date\",axis = 1).astype(\"float32\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#用某日前8天窗口数据作为输入预测该日数据\n",
    "WINDOW_SIZE = 8\n",
    "\n",
    "def batch_dataset(dataset):\n",
    "    dataset_batched = dataset.batch(WINDOW_SIZE,drop_remainder=True)\n",
    "    return dataset_batched\n",
    "\n",
    "ds_data = tf.data.Dataset.from_tensor_slices(tf.constant(dfdiff.values,dtype = tf.float32)) \\\n",
    "   .window(WINDOW_SIZE,shift=1).flat_map(batch_dataset)\n",
    "\n",
    "ds_label = tf.data.Dataset.from_tensor_slices(\n",
    "    tf.constant(dfdiff.values[WINDOW_SIZE:],dtype = tf.float32))\n",
    "\n",
    "#数据较小，可以将全部训练数据放入到一个batch中，提升性能\n",
    "ds_train = tf.data.Dataset.zip((ds_data,ds_label)).batch(38).cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#考虑到新增确诊，新增治愈，新增死亡人数数据不可能小于0，设计如下结构\n",
    "class Block(layers.Layer):\n",
    "    def __init__(self, **kwargs):\n",
    "        super(Block, self).__init__(**kwargs)\n",
    "    \n",
    "    def call(self, x_input,x):\n",
    "        x_out = tf.maximum((1+x)*x_input[:,-1,:],0.0)\n",
    "        return x_out\n",
    "    \n",
    "    def get_config(self):  \n",
    "        config = super(Block, self).get_config()\n",
    "        return config\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf.keras.backend.clear_session()\n",
    "x_input = layers.Input(shape = (None,3),dtype = tf.float32)\n",
    "x = layers.LSTM(3,return_sequences = True,input_shape=(None,3))(x_input)\n",
    "x = layers.LSTM(3,return_sequences = True,input_shape=(None,3))(x)\n",
    "x = layers.LSTM(3,return_sequences = True,input_shape=(None,3))(x)\n",
    "x = layers.LSTM(3,input_shape=(None,3))(x)\n",
    "x = layers.Dense(3)(x)\n",
    "\n",
    "#考虑到新增确诊，新增治愈，新增死亡人数数据不可能小于0，设计如下结构\n",
    "#x = tf.maximum((1+x)*x_input[:,-1,:],0.0)\n",
    "x = Block()(x_input,x)\n",
    "model = models.Model(inputs = [x_input],outputs = [x])\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练模型\n",
    "注：循环神经网络调试较为困难，需要设置多个不同的学习率多次尝试，以取得较好的效果。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#自定义损失函数，考虑平方差和预测目标的比值\n",
    "class MSPE(losses.Loss):\n",
    "    def call(self,y_true,y_pred):\n",
    "        err_percent = (y_true - y_pred)**2/(tf.maximum(y_true**2,1e-7))\n",
    "        mean_err_percent = tf.reduce_mean(err_percent)\n",
    "        return mean_err_percent\n",
    "    \n",
    "    def get_config(self):\n",
    "        config = super(MSPE, self).get_config()\n",
    "        return config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import datetime\n",
    "\n",
    "optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)\n",
    "model.compile(optimizer=optimizer,loss=MSPE(name = \"MSPE\"))\n",
    "\n",
    "stamp = datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
    "logdir = os.path.join('data', 'autograph', stamp)\n",
    "\n",
    "## 在 Python3 下建议使用 pathlib 修正各操作系统的路径\n",
    "# from pathlib import Path\n",
    "# stamp = datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
    "# logdir = str(Path('./data/autograph/' + stamp))\n",
    "\n",
    "tb_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)\n",
    "#如果loss在100个epoch后没有提升，学习率减半。\n",
    "lr_callback = tf.keras.callbacks.ReduceLROnPlateau(monitor=\"loss\",factor = 0.5, patience = 100)\n",
    "#当loss在200个epoch后没有提升，则提前终止训练。\n",
    "stop_callback = tf.keras.callbacks.EarlyStopping(monitor = \"loss\", patience= 200)\n",
    "callbacks_list = [tb_callback,lr_callback,stop_callback]\n",
    "\n",
    "history = model.fit(ds_train,epochs=500,callbacks = callbacks_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 评估模型\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'svg'\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def plot_metric(history, metric):\n",
    "    train_metrics = history.history[metric]\n",
    "    epochs = range(1, len(train_metrics) + 1)\n",
    "    plt.plot(epochs, train_metrics, 'bo--')\n",
    "    plt.title('Training '+ metric)\n",
    "    plt.xlabel(\"Epochs\")\n",
    "    plt.ylabel(metric)\n",
    "    plt.legend([\"train_\"+metric])\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_metric(history,\"loss\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 使用模型\n",
    "此处我们使用模型预测疫情结束时间，即 新增确诊病例为0 的时间。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#使用dfresult记录现有数据以及此后预测的疫情数据\n",
    "dfresult = dfdiff[[\"confirmed_num\",\"cured_num\",\"dead_num\"]].copy()\n",
    "dfresult.tail()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#预测此后100天的新增走势,将其结果添加到dfresult中\n",
    "for i in range(100):\n",
    "    arr_predict = model.predict(tf.constant(tf.expand_dims(dfresult.values[-38:,:],axis = 0)))\n",
    "\n",
    "    dfpredict = pd.DataFrame(tf.cast(tf.floor(arr_predict),tf.float32).numpy(),\n",
    "                columns = dfresult.columns)\n",
    "    dfresult = dfresult.append(dfpredict,ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dfresult.query(\"confirmed_num==0\").head()\n",
    "\n",
    "# 第55天开始新增确诊降为0，第45天对应3月10日，也就是10天后，即预计3月20日新增确诊降为0\n",
    "# 注：该预测偏乐观"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dfresult.query(\"cured_num==0\").head()\n",
    "\n",
    "# 第164天开始新增治愈降为0，第45天对应3月10日，也就是大概4个月后，即7月10日左右全部治愈。\n",
    "# 注: 该预测偏悲观，并且存在问题，如果将每天新增治愈人数加起来，将超过累计确诊人数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dfresult.query(\"dead_num==0\").head()\n",
    "\n",
    "# 第60天开始，新增死亡降为0，第45天对应3月10日，也就是大概15天后，即20200325\n",
    "# 该预测较为合理"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
